import torch
import torch.nn.functional as F
from gymnasium.spaces import Discrete


class ActionDecoder(torch.nn.Module):
    """Linear policy head that maps embeddings to action logits.

    The head can be widened in place when a later task has a larger discrete
    action space (``expand_output``), and it carries the state needed for an
    Elastic Weight Consolidation (EWC) penalty: a diagonal Fisher estimate
    and a snapshot of the parameters taken when that estimate was computed.
    """

    def __init__(self, embedding_size, num_action):
        """Build the decoder.

        Args:
            embedding_size: Size of the last dimension of the input embedding.
            num_action: Number of action logits produced.
        """
        super().__init__()
        self.embedding_size = embedding_size
        self.num_action = num_action
        # A single linear projection: one logit per action.
        self.decode_layers = torch.nn.Linear(embedding_size, num_action)
        # EWC state, filled in by compute_fisher_information().
        self.fisher_information = None
        self.optimal_params = {}

    def forward(self, embedding, action_space: "Discrete | None" = None):
        """Compute policy logits, optionally truncated to an action space.

        Args:
            embedding: Tensor whose last dimension is ``embedding_size``.
            action_space: Optional object exposing ``n`` (the number of valid
                actions). When omitted, the full logits are returned.

        Returns:
            A pair ``(logits_for_action_space, full_logits)``. Without an
            ``action_space`` the pair is ``(full_logits, None)``.
        """
        policy_logits = self.decode_layers(embedding)
        if action_space is None:
            return policy_logits, None
        # Keep only the leading logits that fit the current action space.
        # TODO: restrict the action range with a mask instead of truncating.
        action_size = action_space.n
        if action_size < policy_logits.shape[-1]:
            # Slice the last axis only, so inputs with extra leading
            # dimensions are handled the same way as (batch, embedding).
            return policy_logits[..., :action_size], policy_logits
        return policy_logits, policy_logits

    def expand_output(self, new_num_action, fisher_dataset=None):
        """Widen the output layer to ``new_num_action`` logits.

        Existing rows of the weight and bias are preserved; the additional
        rows keep the default ``torch.nn.Linear`` initialisation. Nothing
        happens when ``new_num_action`` does not exceed the current size.

        Args:
            new_num_action: Desired number of action logits.
            fisher_dataset: Currently unused; kept for interface
                compatibility (the Fisher computation is not triggered here).
        """
        if new_num_action <= self.num_action:
            return

        old_layer = self.decode_layers
        # Create the replacement on the same device and dtype as the layer it
        # replaces; otherwise a decoder living on GPU (or in float64) would
        # silently end up with a CPU/float32 head.
        new_layer = torch.nn.Linear(self.embedding_size, new_num_action).to(
            device=old_layer.weight.device, dtype=old_layer.weight.dtype
        )
        with torch.no_grad():
            new_layer.weight[: self.num_action].copy_(old_layer.weight)
            new_layer.bias[: self.num_action].copy_(old_layer.bias)
        # NOTE: this registers new Parameter objects; anything holding
        # references to the old ones will not see the new layer.
        self.decode_layers = new_layer
        self.num_action = new_num_action

    def compute_fisher_information(self, data_set, device=None, batch_size=512):
        """Estimate the diagonal Fisher information and snapshot parameters.

        For every mini-batch the negative mean log-likelihood of the recorded
        actions is back-propagated and the squared gradients are accumulated;
        the result is averaged over the number of batches.

        Args:
            data_set: Dataset yielding ``(embedding, action)`` pairs.
                Assumes each batched ``actions`` tensor is 1-D with integer
                indices, as required by ``gather`` — TODO confirm with caller.
            device: Device the batches are moved to. Defaults to the device
                of this module's parameters.
            batch_size: Mini-batch size used for the estimate.

        Raises:
            ValueError: If the data loader yields no batches.
        """
        if device is None:
            device = next(self.parameters()).device

        params = {n: p for n, p in self.named_parameters() if p.requires_grad}
        fisher_information = {n: torch.zeros_like(p) for n, p in params.items()}
        data_loader = torch.utils.data.DataLoader(
            data_set, batch_size=batch_size, shuffle=True
        )

        was_training = self.training
        self.eval()
        num_batches = 0
        try:
            for embeddings, actions in data_loader:
                embeddings, actions = embeddings.to(device), actions.to(device)
                self.zero_grad()
                outputs, _ = self.forward(embeddings)
                log_probs = F.log_softmax(outputs, dim=-1)
                # squeeze(1) removes only the gathered axis.
                action_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
                loss = -action_log_probs.mean()
                # The graph is rebuilt every batch, so it need not be retained.
                loss.backward()
                for n, p in params.items():
                    fisher_information[n] += p.grad.detach().pow(2)
                num_batches += 1
        finally:
            # Do not leak the last batch's gradients into the caller's next
            # optimizer step, and restore the previous train/eval mode.
            self.zero_grad()
            self.train(was_training)

        if num_batches == 0:
            raise ValueError("data_set produced no batches; cannot compute Fisher information")

        for n in fisher_information:
            fisher_information[n] /= num_batches

        self.fisher_information = fisher_information
        # Detach before cloning: a plain clone() stays connected to the live
        # parameter in the autograd graph, which makes the gradient of the
        # EWC penalty cancel to zero.
        self.optimal_params = {n: p.detach().clone() for n, p in params.items()}

    def ewc_loss(self):
        """Return the EWC penalty ``sum(F * (theta - theta_star) ** 2)``.

        Returns:
            ``0.0`` when no Fisher information has been computed, otherwise a
            scalar tensor. When a parameter has changed size since the
            snapshot, only the leading rows common to both are penalised.
        """
        if self.fisher_information is None or self.optimal_params is None:
            return 0.0
        loss = 0.0
        for n, p in self.named_parameters():
            if n not in self.fisher_information:
                continue
            # Follow the parameter's device in case the module was moved
            # after the Fisher information was computed (no-op otherwise).
            fisher = self.fisher_information[n].to(p.device)
            # The snapshot is a constant target; never back-propagate into it.
            optimal_param = self.optimal_params[n].detach().to(p.device)

            # When old and new parameter sizes differ, compare only the
            # overlapping leading rows.
            if p.shape != optimal_param.shape:
                min_size = min(p.shape[0], optimal_param.shape[0])
                fisher = fisher[:min_size]
                p = p[:min_size]
                optimal_param = optimal_param[:min_size]

            loss = loss + (fisher * (p - optimal_param).pow(2)).sum()
        return loss

    def copy_for_inference(self):
        """Return a fresh ``ActionDecoder`` with the same weights.

        The copy does not carry the Fisher information or the parameter
        snapshot; only the ``state_dict`` is transferred.
        """
        new_model = ActionDecoder(self.embedding_size, self.num_action)
        new_model.load_state_dict(self.state_dict())
        return new_model